#!/usr/bin/env python

# make it also work with python 3
from __future__ import print_function

import sys
import os

if sys.version_info < (2, 7, 5):
    # The interpreter that picked up this script is too old.  If the user named
    # a replacement in $UPCXX_PYTHON, re-execute the script under it.  The
    # NORECURSE marker is set first so that a replacement which is *also* too
    # old falls through to the error below instead of re-executing forever.
    if os.environ.get('UPCXX_PYTHON') and not os.environ.get('UPCXX_PYTHON_NORECURSE'):
        os.environ['UPCXX_PYTHON_NORECURSE'] = "1"
        os.execv('/usr/bin/env', ['/usr/bin/env', os.environ['UPCXX_PYTHON']] + sys.argv)
    sys.stderr.write('ERROR: Python version >= 2.7.5 required. Please set $UPCXX_PYTHON to a working interpreter.\n')
    # sys.exit() rather than the bare exit(): the latter is injected by the
    # `site` module for interactive use and is absent under `python -S`.
    sys.exit(1)
if os.environ.get('UPCXX_PYTHON_NORECURSE'):
    # Running under an acceptable interpreter: drop the marker again
    # (presumably so it is not inherited by anything launched later).
    del os.environ['UPCXX_PYTHON_NORECURSE']

import argparse
import subprocess
import string
import re
import copy

# Unit suffixes recognised by memstr_parse(), grouped by the kind of number
# that may precede them.
INT_HEAPSZ_UNITS = ['m', 'mb', 'k', 'kb']      # whole numbers only
FLOAT_HEAPSZ_UNITS = ['g', 'gb', 't', 'tb']    # an optional fraction is allowed
MAX_HEAPSZ_UNITS_SPECIAL = ['%']
MAX_HEAPSZ_UNITS = MAX_HEAPSZ_UNITS_SPECIAL

# Byte multipliers (powers of 1024).
ONE_KB = 1 << 10
ONE_MB = ONE_KB ** 2
ONE_GB = ONE_KB ** 3
ONE_TB = ONE_KB ** 4


# The argparse.ArgumentParser created by load_args(); None until then.
_parser = None

class SegmentStrategy(object):
    """A parsed memory-size request: a magnitude plus how to interpret it.

    ``num`` is the numeric value and ``strategy`` is a short tag describing
    it.  Within this file the tag is '' for an absolute byte count, '%' for a
    percentage-style request, and '/H' when a GASNet segment fraction is built.

    Deriving from ``object`` makes this a new-style class on Python 2 as well.
    """

    def __init__(self, num, strategy):
        self.num = num
        self.strategy = strategy

    def __repr__(self):
        # Readable form for verbose/debug output.
        return 'SegmentStrategy(%r, %r)' % (self.num, self.strategy)

def ranks_type(a):
    """argparse ``type=`` callback for the process count.

    Converts ``a`` with int() and returns it when it is strictly positive.
    Raises argparse.ArgumentTypeError when the text is not an integer or the
    value is zero or negative.
    """
    try:
        count = int(a)
    except ValueError:
        raise argparse.ArgumentTypeError('invalid format (' + a + ') - must be an integer')
    if count > 0:
        return count
    raise argparse.ArgumentTypeError('number of processes must be an integer > 0')

def nodes_type(a):
    """argparse ``type=`` callback for the node count.

    Converts ``a`` with int() and returns it when it is strictly positive.
    Raises argparse.ArgumentTypeError when the text is not an integer or the
    value is zero or negative.
    """
    try:
        count = int(a)
    except ValueError:
        raise argparse.ArgumentTypeError('invalid format (' + a + ') - must be an integer')
    if count > 0:
        return count
    raise argparse.ArgumentTypeError('number of nodes must be an integer > 0')

def upcxx_memstr_parse(s):
    """Parse a heap-size string that may use the two shorthand forms.

    Accepted forms, tried in this order:
      * ``MAX`` (any case)        -> SegmentStrategy(100, '%')
      * a bare non-negative integer, taken as a count of megabytes
                                  -> SegmentStrategy(n * ONE_MB, '')
      * anything else is handed to memstr_parse()

    Returns a SegmentStrategy, or None when memstr_parse() rejects the string.
    """
    # Raw-string patterns: '\s' and '\d' are not valid string escapes, and
    # newer Python versions warn about them in ordinary literals.
    if re.search(r'^MAX$', s, re.IGNORECASE):
        return SegmentStrategy(100, '%')

    bare_int = re.search(r'^\s*(\d+)\s*$', s)
    if bare_int:
        return SegmentStrategy(int(bare_int.group(1)) * ONE_MB, '')

    return memstr_parse(s)

def memstr_parse(s):
    """Parse a memory-size string that carries a unit suffix.

    Matching is case-insensitive and tolerates whitespace around the number
    and the unit.  Three shapes are recognised:

      * real number + one of FLOAT_HEAPSZ_UNITS (g/gb/t/tb)
            -> SegmentStrategy(<bytes, truncated to int>, '')
      * positive integer without leading zeros + one of INT_HEAPSZ_UNITS
        (m/mb/k/kb)
            -> SegmentStrategy(<bytes>, '')
      * real number + one of MAX_HEAPSZ_UNITS_SPECIAL (%)
            -> SegmentStrategy(<value / 100.0>, '%')

    Returns None when ``s`` matches none of them.
    """
    # Raw strings throughout: '\s', '\d' and '\.' are not valid string escapes.
    real = r'(?:[1-9]\d*|0)(?:\.\d+)?'
    positive_int = r'[1-9]\d*'

    def match(number, units):
        # group(1) is the number, group(2) the unit as the user typed it.
        alternatives = '|'.join(re.escape(u) for u in units)
        return re.search(r'^\s*(' + number + r')\s*(' + alternatives + r')\s*$',
                         s, re.IGNORECASE)

    found = match(real, FLOAT_HEAPSZ_UNITS)
    if found:
        # Units here are g/gb/t/tb, so the first letter decides the scale.
        scale = ONE_TB if found.group(2).upper().startswith('T') else ONE_GB
        return SegmentStrategy(int(float(found.group(1)) * scale), '')

    found = match(positive_int, INT_HEAPSZ_UNITS)
    if found:
        # Units here are m/mb/k/kb.
        scale = ONE_KB if found.group(2).upper().startswith('K') else ONE_MB
        return SegmentStrategy(int(found.group(1)) * scale, '')

    found = match(real, MAX_HEAPSZ_UNITS_SPECIAL)
    if found:
        # Percentages are stored as a fraction; range checking is left to the caller.
        return SegmentStrategy(float(found.group(1)) / 100.0, str(found.group(2)))

    return None
 

def heapsz_type(a):
    """argparse ``type=`` callback for -shared-heap.

    Returns whatever memstr_parse() produces for ``a``; when that is None,
    raises argparse.ArgumentTypeError listing the accepted unit suffixes.
    """
    parsed = memstr_parse(a)
    if parsed is not None:
        return parsed

    int_units = ', '.join(INT_HEAPSZ_UNITS)
    float_units = ', '.join(FLOAT_HEAPSZ_UNITS)
    max_units = ', '.join(MAX_HEAPSZ_UNITS)
    raise argparse.ArgumentTypeError(
        'shared heap must be an integer > 0 and have a suffix of one of: ' + int_units
        + ', a real number > 0 and have a suffix of one of: ' + float_units
        + ', or be a fraction of the main memory available on one node and have a suffix of one of: ' + max_units)

    
def load_args():
    """Build the command-line parser, keep it in ``_parser``, and parse argv.

    The parser is stored in the module-level ``_parser`` so that
    print_error_and_abort() can print the usage line later.  Returns the
    namespace produced by ``parse_args()``.
    """
    global _parser
    _parser = argparse.ArgumentParser(
        description='A portable parallel job launcher for UPC++ programs, v2021.3.0')
    option = _parser.add_argument

    option('-n', '-np', dest='nranks', type=ranks_type, default=None, metavar='NUM',
           help='Spawn NUM number of UPC++ processes. Required.')
    option('-N', dest='nnodes', type=nodes_type, default=None, metavar='NUM',
           help='Run on NUM of nodes.')
    option('-shared-heap', dest='heapsz', type=heapsz_type,
           help='Requests HEAPSZ size of shared memory per process. '
                'HEAPSZ must include a unit suffix matching the pattern "[KMGT]B?" or be "[0-100]%%" (case-insensitive).')

    # Three plain on/off switches, registered in the order they appear in --help.
    switches = (
        ('-backtrace', 'backtrace',
         'Enable backtraces. Compile with -g for full debugging information.'),
        ('-show', 'show',
         "Testing: don't start the job, just output the command line that would "
         "have been executed"),
        ('-info', 'info',
         'Display useful information about the executable'),
    )
    for flag, dest, text in switches:
        option(flag, dest=dest, default=False, action='store_true', help=text)

    option('-ssh-servers', dest='sshservers', type=str, default=None, metavar='HOSTS',
           help='List of SSH servers, comma separated.')
    option('-localhost', dest='localhost', default=False, action='store_true',
           help='Run UDP-conduit program on local host only')
    option('-v', dest='verbose', default=0, action='count',
           help='Generate verbose output. Multiple applications increase verbosity.')
    option('-E', dest='env_list', default=[], action='append', metavar='VAR1[,VAR2]',
           help='Adds provided arguments to the list of environment variables to propagate to compute processes.')

    # Everything from the first positional onward is the command to launch.
    command_group = _parser.add_argument_group('command to execute')
    command_group.add_argument('cmd', default=None, metavar='command', help='UPC++ executable')
    command_group.add_argument('cmd_args', default=[], nargs=argparse.REMAINDER, metavar='...',
                               help='arguments')

    # Rename the heading argparse prints above the optional arguments.
    _parser._optionals.title = 'options'
    return _parser.parse_args()


def print_error_and_abort(err):
    """Print the parser's usage line, then ``err`` on stderr, and exit with status 1."""
    message = '\nError: ' + err
    _parser.print_usage()
    print(message, file=sys.stderr)
    raise SystemExit(1)
    

def is_printable(text, printables=""):
    """Return True when every character of ``str(text)`` is printable.

    "Printable" means a member of ``string.printable`` or of the extra
    characters supplied in ``printables``.  An empty text is printable.
    """
    allowed = set(string.printable + printables)
    return all(ch in allowed for ch in str(text))


def print_env_info(command):
    """Print the identification fields found inside the file ``command``.

    The file is read line by line as text (ISO-8859-1 on Python 3, so any
    byte value decodes) and each line is split on '$'.  A field is printed,
    stripped, when it is longer than one character, fully printable, starts
    with one of the GASNet/UPC++/AMUDP prefixes, contains ':' and contains
    no '%'.

    On an IOError the error is reported and the process exits with status 1.
    """
    prefixes = ('GASNET', 'GASNet', 'UPCXX', 'upcxx', 'AMUDP')
    try:
        if sys.version_info[0] == 2:
            f = open(command, 'r')
        else:
            f = open(command, 'r', encoding='ISO-8859-1')
        # The context manager closes the file on every path, including errors
        # raised part-way through reading.
        with f:
            for line in f:
                for field in line.split('$'):
                    if len(field) <= 1 or not is_printable(field):
                        continue
                    if field.startswith(prefixes) and ':' in field and '%' not in field:
                        print(field.strip())
    except IOError as err:
        # strerror can be None for some errors; fall back to str(err).
        print('Error: ' + (err.strerror or str(err)) + ': "' + command + '"')
        sys.exit(1)


def get_key_in_exec(command, key):
    """Return the value stored after ``key`` inside the file ``command``.

    The file is read line by line as text (ISO-8859-1 on Python 3, so any
    byte value decodes) and each line is split on '$'.  For the first field
    that starts with ``key``, the remainder of that field is returned with
    surrounding whitespace stripped.  Returns None when no field matches.

    On an IOError the error is reported and the process exits with status 1.
    """
    try:
        if sys.version_info[0] == 2:
            f = open(command, 'r')
        else:
            f = open(command, 'r', encoding='ISO-8859-1')
        # The context manager closes the file even on the early return below.
        with f:
            for line in f:
                for field in line.split('$'):
                    if field.startswith(key):
                        return field[len(key):].strip()
    except IOError as err:
        # strerror can be None for some errors; fall back to str(err).
        print('Error: ' + (err.strerror or str(err)) + ': "' + command + '"')
        sys.exit(1)
    return None
    

def get_conduit(command):
    """Return the lower-cased 'GASNetConduitName:' value found in ``command``.

    Returns None when get_key_in_exec() finds nothing or an empty value.
    """
    name = get_key_in_exec(command, 'GASNetConduitName:')
    if name:
        return name.lower()
    return None


def set_ssh_servers(ssh_servers, nranks, conduit):
    """Validate the SSH host list and export it for the spawner.

    The host list comes from ``ssh_servers`` when that is non-empty, otherwise
    from $GASNET_SSH_SERVERS, otherwise from $SSH_SERVERS.  For the udp
    conduit a non-empty list with fewer entries than ``nranks`` aborts with a
    usage error.  When ``ssh_servers`` was given, it is written to
    SSH_SERVERS and GASNET_SSH_SERVERS, and the spawn-method variables
    GASNET_SPAWNFN / GASNET_IBV_SPAWNFN are set to their ssh values.
    """
    delimiters = ',|/|;|:| '
    if conduit != 'udp':
        # conduits other than udp additionally split on tab/newline/CR
        delimiters += '|\t|\n|\r'

    # Pick the host list: the explicit argument wins over the environment.
    if ssh_servers:
        host_spec = ssh_servers
    else:
        host_spec = None
        for var in ('GASNET_SSH_SERVERS', 'SSH_SERVERS'):
            if var in os.environ:
                host_spec = os.environ[var]
                break

    if host_spec is None:
        num_servers = 0
    else:
        num_servers = len(re.split(delimiters, host_spec))

    if conduit == 'udp' and 0 < num_servers < nranks:
        print_error_and_abort('For UDP conduit, need to specify at least one host per process in '
                              '-ssh-servers or environment variable GASNET_SSH_SERVERS, or use -localhost parameter')

    # FIXME: this leaves all the checking to gasnet. This has the disadvantage that the error
    # messages will only refer to environment variables as solutions, and not mention parameters
    # to upcxx-run. Hopefully we can make this more friendly in future.
    if num_servers == 0 or not ssh_servers:
        return

    exports = (
        ('SSH_SERVERS', ssh_servers),
        ('GASNET_SSH_SERVERS', ssh_servers),
        ('GASNET_SPAWNFN', 'S'),
        ('GASNET_IBV_SPAWNFN', 'ssh'),
    )
    for var, value in exports:
        os.environ[var] = value

    
def main():
    """Parse the command line, prepare the environment, and launch the job.

    Steps, in order:
      1. parse arguments and locate the UPC++/GASNet executable among the
         command words (the first regular file that carries a conduit name);
      2. with -info, print the executable's embedded settings and return 0;
      3. choose the spawner for the conduit and prepend its arguments;
      4. locate the spawner binary on a list of candidate directories;
      5. derive UPCXX_SHARED_HEAP_SIZE / GASNET_MAX_SEGSIZE;
      6. add the -E environment propagation list for spawners that need it;
      7. print the command (verbose / -show) and, unless -show, exec it.

    On success the process image is replaced by os.execvp, so this function
    only returns for -info (0) and -show (None).  Usage errors and a failed
    exec terminate the process with status 1.
    """
    args = load_args()

    def print_verbose(*a, **kwa):
        if args.verbose:
            print(" ".join(map(str, a)), **kwa)

    def print_info(*a, **kwa):
        if args.info:
            print(" ".join(map(str, a)), **kwa)

    if not args.info and not args.nranks:
        print_error_and_abort('Must specify the number of processes, -n or -np')

    # --- 1. find the executable -------------------------------------------
    # The command may be wrapped (e.g. by a debugger), so scan every word for
    # the first file that identifies itself as a GASNet program.
    cmd = [args.cmd] + args.cmd_args
    executable = None
    conduit = None
    for i, word in enumerate(cmd):
        if not os.path.isfile(word):
            continue
        conduit = get_conduit(word)
        if conduit:
            if os.path.basename(word) == word:
                # bare file name: make it explicit so it is not looked up on PATH
                cmd[i] = './' + word
            executable = cmd[i]
            break

    if executable is None:
        print_error_and_abort('"' + ' '.join(cmd) + '" does not appear to execute a UPC++/GASNet executable')

    # --- 2. -info ----------------------------------------------------------
    if args.info:
        print_env_info(executable)
        return 0

    # --- 3. conduit-specific setup ------------------------------------------
    spawner = None
    if args.verbose > 1:
        os.environ['GASNET_VERBOSEENV'] = '1'
        os.environ['UPCXX_VERBOSE'] = '1'

    print_verbose('UPCXX_RUN:', executable, 'is compiled with', conduit, 'conduit')

    if args.sshservers and args.localhost:
        print_error_and_abort('Conflicting options: cannot specify both -ssh-servers and -localhost')

    if conduit == 'smp':
        # smp runs the executable directly; the process count goes via the environment
        os.environ['GASNET_PSHM_NODES'] = str(args.nranks)
        if args.nnodes and args.nnodes != 1:
            print('WARNING: '+executable+' was compiled with smp backend, which only supports single-node operation. -N option ignored.')
            print('WARNING: For multi-node operation, please re-compile with UPCXX_NETWORK=<backend>')
            args.nnodes = 1
    elif conduit == 'udp':
        spawner = 'amudprun'
        cmd = ['-np', str(args.nranks)] + cmd
        if args.localhost:
            os.environ['GASNET_SPAWNFN'] = 'L'
        elif os.environ.get('GASNET_SPAWNFN') in [None, 'S', 's']:
            set_ssh_servers(args.sshservers, args.nranks, conduit)
    elif conduit in ['mpi', 'ibv', 'aries', 'pami', 'ucx', 'ofi']:
        spawner = 'gasnetrun_' + conduit
        cmd = ['-n', str(args.nranks)] + cmd
        if args.nnodes:
            cmd = ['-N', str(args.nnodes)] + cmd
        if conduit in ['ibv', 'ucx']:
            set_ssh_servers(args.sshservers, args.nranks, conduit)
    else:
        print_error_and_abort('Unknown GASNet conduit: ' + conduit)

    if conduit != 'udp' and args.localhost:
        print('Ignoring -localhost setting: only applies to UDP conduit, not', conduit)

    # FIXME: ucx and ofi intentionally omitted from warning text until we claim support
    if args.sshservers and conduit not in ['udp', 'ibv', 'ucx', 'ofi']:
        print('Ignoring -ssh-servers setting: only applies to UDP and IBV conduits, not', conduit)

    # --- 4. locate the spawner binary ---------------------------------------
    if spawner:
        print_verbose('UPCXX_RUN: Looking for spawner "' + spawner + '" in:')
        paths = []
        gasnet_prefix = os.environ.get('GASNET_PREFIX')
        if gasnet_prefix is not None:
            paths.append(gasnet_prefix + '/bin')
            paths.append(gasnet_prefix + '/' + conduit + '-conduit/contrib')
            paths.append(gasnet_prefix + '/other/amudp')
        # directories relative to where this script itself lives
        mypath = os.path.dirname(os.path.realpath(__file__))
        paths.append(mypath + '/../gasnet.opt/bin')
        paths.append(mypath + '/../gasnet.debug/bin')
        upcxx_install = os.environ.get('UPCXX_INSTALL')
        if upcxx_install is not None:
            paths.append(upcxx_install + '/gasnet.opt/bin')
            paths.append(upcxx_install + '/gasnet.debug/bin')

        spawner_path = None
        for path in paths:
            print_verbose('UPCXX_RUN:    ' + path)
            candidate = path + '/' + spawner
            if os.path.exists(path) and os.path.exists(candidate):
                spawner_path = candidate
                break
        if spawner_path:
            print_verbose('UPCXX_RUN: Found spawner "' + spawner_path + '"')
            # from here on `spawner` is the argv prefix, not just a name
            spawner = [spawner_path]
            if args.verbose > 2:
                spawner += ['-v']
        else:
            print_error_and_abort('Could not find spawner "' + spawner + '"')

    # --- 5. shared heap size --------------------------------------------------
    # NOTE(review): no `from __future__ import division` is visible in this
    # file, so `num / ONE_MB` below is integer division on Python 2 and true
    # division on Python 3 ('128' vs '128.0').  Left as-is; confirm which
    # form the runtime expects before changing it.
    heapkey_new = 'UPCXX_SHARED_HEAP_SIZE'
    heapkey_old = 'UPCXX_SEGMENT_MB'
    if args.heapsz is None:
        # No -shared-heap option: fall back to the environment.
        heapkey = heapkey_new
        heapval = os.environ.get(heapkey)
        if not heapval and os.environ.get(heapkey_old): # backwards compat
            heapval = os.environ.get(heapkey_old)
            heapkey = heapkey_old
            os.environ[heapkey_new] = os.environ[heapkey_old]

        if heapval:
            provided_upcxx_segsize_strat = upcxx_memstr_parse(heapval)
            if provided_upcxx_segsize_strat is None:
                print_error_and_abort(heapkey + ' setting is wrong. It must either be MAX, or a valid heap size expression.')
            elif os.environ.get('GASNET_MAX_SEGSIZE') is not None:
                print_verbose('UPCXX_RUN: '+heapkey+' and GASNET_MAX_SEGSIZE have both been provided and are set to '+heapkey+' = '+ \
                  heapval + ' and GASNET_MAX_SEGSIZE = ' + os.environ.get('GASNET_MAX_SEGSIZE') + '.',
                  'Please make sure that the value of GASNET_MAX_SEGSIZE is sufficiently large to accomodate UPC++\'s shared segment')
            elif provided_upcxx_segsize_strat.strategy == '':
                # Only an absolute byte count can be turned into "<n>MB/P".
                # MAX (in any letter case) and percentage values carry the '%'
                # strategy and hold no byte count, so GASNET_MAX_SEGSIZE is
                # left untouched for them.
                segsize = str(provided_upcxx_segsize_strat.num / ONE_MB) + 'MB/P'
                print_verbose('UPCXX_RUN: GASNET_MAX_SEGSIZE is set to', segsize)
                os.environ['GASNET_MAX_SEGSIZE'] = segsize
        else:
            # Nothing requested anywhere: default to 128 MB.
            args.heapsz = SegmentStrategy(128*ONE_MB, '')
            os.environ['UPCXX_SHARED_HEAP_SIZE'] = str(args.heapsz.num / ONE_MB) + ' MB'
            if os.environ.get('GASNET_MAX_SEGSIZE') is not None:
                print_verbose('UPCXX_RUN: GASNET_MAX_SEGSIZE has been provided and set to ' + os.environ.get('GASNET_MAX_SEGSIZE') + '.',
                      'Please make sure that this value is sufficiently large to accomodate UPC++\'s shared segment (the default value is 128MB)')
            else:
                os.environ['GASNET_MAX_SEGSIZE'] = str(args.heapsz.num / ONE_MB) + 'MB/P'
    elif args.heapsz.strategy == '':
        # -shared-heap with an absolute size
        os.environ['UPCXX_SHARED_HEAP_SIZE'] = str(args.heapsz.num / ONE_MB) + ' MB'
        os.environ['GASNET_MAX_SEGSIZE'] = str(args.heapsz.num / ONE_MB) + 'MB/P'
        print_verbose('UPCXX_RUN: GASNET_MAX_SEGSIZE is set to', str(args.heapsz.num / ONE_MB) + 'MB/P')
    elif args.heapsz.strategy == r'%':
        # -shared-heap with a percentage; num is already a fraction here
        if args.heapsz.num > 1.0 or args.heapsz.num <= 0.0:
            print_error_and_abort('shared-heap value has to be greater than 0% and cannot exceed 100%')
        os.environ['UPCXX_SHARED_HEAP_SIZE'] = 'MAX'
        segsize = str(args.heapsz.num) + '/H'
        print_verbose('UPCXX_RUN: GASNET_MAX_SEGSIZE is set to', segsize)
        os.environ['GASNET_MAX_SEGSIZE'] = segsize

    if os.environ.get(heapkey_old): # purge the old setting to avoid confusion
        del os.environ[heapkey_old]

    if args.backtrace:
        os.environ['GASNET_BACKTRACE'] = '1'

    # --- 6. environment propagation -------------------------------------------
    # issue #265: ensure automatic env propagation for UPCXX_* variables
    if conduit not in ['udp', 'smp']:
        # each -E value may itself be a comma-separated list
        args.env_list = [name for item in args.env_list for name in item.split(',')]
        env_list = [env_key for env_key in sorted(os.environ)
                    if env_key.startswith('UPCXX_')
                    or env_key.startswith('UPC_')
                    or env_key in args.env_list]
        if env_list:
            spawner += ['-E', ','.join(env_list)]

    # --- 7. report and launch ---------------------------------------------------
    # prepend the spawn command, if any
    if spawner:
        cmd = spawner + cmd

    if args.verbose or args.show:
        print('UPCXX_RUN: Environment:')
        for env_key in sorted(os.environ):
            if env_key.startswith('GASNET') or env_key.startswith('UPCXX') or env_key == 'SSH_SERVERS':
                print('UPCXX_RUN:   ', env_key, '=', os.environ[env_key])
        print('UPCXX_RUN: Command line:\n')
        print('    ' + ' '.join(cmd))
        print('')

    if not args.show:
        # flush first: execvp replaces the process, discarding buffered output
        sys.stdout.flush()
        try:
            os.execvp(cmd[0], cmd)
        except OSError as err:
            print("UPCXX_RUN: OS error executing '" + ' '.join(cmd) +
                  "'\n    {0}".format(err))
            # A launcher that failed to launch must not report success.
            sys.exit(1)
        
if __name__ == '__main__':
    # main() returns 0 or None, both of which sys.exit() maps to exit status 0.
    sys.exit(main())
